# Package setup: agtboost provides the gradient-boosting trainer (gbt.train);
# MASS/ISLR/ElemStatLearn presumably supply the benchmark datasets used by
# dataset() below — TODO confirm against functions.r; ggplot2 for plotting.
library(MASS)
library(ISLR)
library(ElemStatLearn)
library(agtboost)
library(ggplot2)
# NOTE(review): machine-specific absolute path — breaks portability; consider
# a project-relative path instead of setwd().
setwd("/Users/eirik/OneDrive - University of Bergen")
# Project helpers: dataset(), sampling.agt.train/pred/loss, loss_mse, agt.loss,
# and the *.sampling.agt.train variants used throughout this script.
source("functions.r")
Testing the default model with sample size = 1 to see if it produces similar results to standard agtboost.
## Sanity check: fit the sampling trainer with samSize = 1 and plain agtboost
## (gbt.train) on the same data and seeds; with full sampling the two should
## produce (near-)identical train/test/estimated-generalization losses.
set.seed(1295)
seed <- 1295
B <- 10          # replications (one seed each) per dataset
n_datasets <- 7  # datasets handled by dataset()
seeds <- sample(1e5, B)
param <- list(learning_rate = 0.1, samSize = 1, nrounds = 1000)

agt_models <- list()       # agt_models[[i]][[b]]: dataset i, replication b
sam_models <- list()       # sampling models, same indexing
sam_test_loss <- list()    # per-iteration test loss, sampling models
agt_test_loss <- list()    # per-iteration test loss, agt models
agt_train_loss <- list()   # per-iteration train loss, agt models
agt_estgen_loss <- list()  # estimated generalization loss per tree, agt models
res <- list()              # res[[i]]: B x 2 matrix of final test MSE (sam, agt)

# One progress bar for the whole experiment. It was previously re-created
# inside the dataset loop, which reset the display to 0% on every dataset.
pb <- txtProgressBar(min = 0, max = B * n_datasets, style = 3)
for (i in seq_len(n_datasets)) {
  # Preallocate per-dataset containers instead of growing them with c().
  sub_models <- vector("list", B)
  sub_agt_models <- vector("list", B)
  sub_test_loss <- vector("list", B)
  sub_agt_test_loss <- vector("list", B)
  sub_agt_train_loss <- vector("list", B)
  sub_agt_estgen_loss <- vector("list", B)
  res_mat <- matrix(nrow = B, ncol = 2)
  for (b in seq_len(B)) {
    # Loads x.train/y.train/x.test/y.test into the global env (functions.r).
    dataset(i, seeds[b])
    # Re-seed identically before each trainer so stochastic steps match.
    set.seed(seeds[b])
    mod <- sampling.agt.train(x.train, y.train, samSize = param$samSize,
                              learnRate = param$learning_rate,
                              Nrounds = param$nrounds, verbose = 0,
                              algorithm = "global_subset", type = "reg",
                              force_continued_learning = FALSE)
    sub_models[[b]] <- mod
    sam.pred <- sampling.agt.pred(mod, x.test)
    sub_test_loss[[b]] <- sampling.agt.loss(mod, x.test, y.test)

    set.seed(seeds[b])
    agt.mod <- gbt.train(y.train, x.train, param$learning_rate,
                         nrounds = param$nrounds, verbose = 0,
                         algorithm = "global_subset")
    sub_agt_models[[b]] <- agt.mod
    agt.pred <- predict(agt.mod, x.test)
    sub_agt_test_loss[[b]] <- agt.loss(agt.mod, x.test, y.test)
    sub_agt_train_loss[[b]] <- agt.loss(agt.mod, x.train, y.train)
    # Estimated generalization loss after each fitted tree of the ensemble.
    sub_agt_estgen_loss[[b]] <-
      sapply(seq_len(agt.mod$get_num_trees()),
             agt.mod$estimate_generalization_loss)

    res_mat[b, 1] <- loss_mse(y.test, sam.pred)
    res_mat[b, 2] <- loss_mse(y.test, agt.pred)
    # (i - 1) * B + b: works for any B (was hard-coded to 10).
    setTxtProgressBar(pb, (i - 1) * B + b)
  }
  res[[i]] <- res_mat  # store once per dataset (was re-assigned every b)
  sam_models <- c(sam_models, list(sub_models))
  agt_models <- c(agt_models, list(sub_agt_models))
  sam_test_loss <- c(sam_test_loss, list(sub_test_loss))
  agt_test_loss <- c(agt_test_loss, list(sub_agt_test_loss))
  agt_train_loss <- c(agt_train_loss, list(sub_agt_train_loss))
  agt_estgen_loss <- c(agt_estgen_loss, list(sub_agt_estgen_loss))
}
close(pb)
##
|
| | 0%
|
|= | 1%
|
|== | 3%
|
|=== | 4%
|
|==== | 6%
|
|===== | 7%
|
|====== | 9%
|
|======= | 10%
|
|======== | 11%
|
|========= | 13%
|
|========== | 14%
|
| | 0%
|
|=========== | 16%
|
|============ | 17%
|
|============= | 19%
|
|============== | 20%
|
|=============== | 21%
|
|================ | 23%
|
|================= | 24%
|
|================== | 26%
|
|=================== | 27%
|
|==================== | 29%
|
| | 0%
|
|===================== | 30%
|
|====================== | 31%
|
|======================= | 33%
|
|======================== | 34%
|
|========================= | 36%
|
|========================== | 37%
|
|=========================== | 39%
|
|============================ | 40%
|
|============================= | 41%
|
|============================== | 43%
|
| | 0%
|
|=============================== | 44%
|
|================================ | 46%
|
|================================= | 47%
|
|================================== | 49%
|
|=================================== | 50%
|
|==================================== | 51%
|
|===================================== | 53%
|
|====================================== | 54%
|
|======================================= | 56%
|
|======================================== | 57%
|
| | 0%
|
|========================================= | 59%
|
|========================================== | 60%
|
|=========================================== | 61%
|
|============================================ | 63%
|
|============================================= | 64%
|
|============================================== | 66%
|
|=============================================== | 67%
|
|================================================ | 69%
|
|================================================= | 70%
|
|================================================== | 71%
|
| | 0%
|
|=================================================== | 73%
|
|==================================================== | 74%
|
|===================================================== | 76%
|
|====================================================== | 77%
|
|======================================================= | 79%
|
|======================================================== | 80%
|
|========================================================= | 81%
|
|========================================================== | 83%
|
|=========================================================== | 84%
|
|============================================================ | 86%
|
| | 0%
|
|============================================================= | 87%
|
|============================================================== | 89%
|
|=============================================================== | 90%
|
|================================================================ | 91%
|
|================================================================= | 93%
|
|================================================================== | 94%
|
|=================================================================== | 96%
|
|==================================================================== | 97%
|
|===================================================================== | 99%
|
|======================================================================| 100%
Comparing train, test and estimated generalization loss for agt and sampling agt to make sure they are similar.
The results are very similar; hopefully it is the same model up to some random noise.
## Overlay the train/test/estimated-generalization loss curves of the
## sampling fit and the plain agt fit for selected datasets/replications.
## With samSize = 1 the curves should coincide up to noise.
sets <- 1:7           # datasets to inspect (1..7)
iters <- c(1, 5, 10)  # replications (seeds) to inspect (1..10)
for (i in sets) {
  for (b in iters) {
    sam_fit <- sam_models[[i]][[b]]
    plot(sam_fit$train.loss, main = paste("dataset: ", i), cex = 1.5,
         ylab = "Loss", xlab = "iter", type = "l")
    lines(sam_test_loss[[i]][[b]], col = "green", cex = 1.5)
    lines(sam_fit$gen.loss, col = "red", cex = 1.5)
    lines(agt_estgen_loss[[i]][[b]], col = "brown", cex = 0.8)
    lines(agt_test_loss[[i]][[b]], col = "blue", cex = 0.8)
    lines(agt_train_loss[[i]][[b]], col = "pink", cex = 0.8)
    legend(x = "topright",
           legend = c("Sam Train MSE", "Sam Test MSE", "Sam Estgen loss",
                      "Agt Estgen loss", "agt Test MSE", "agt Train MSE"),
           col = c("black", "green", "red", "brown", "blue", "pink"),
           lwd = 1, lty = c(1))
  }
}
Training the default model for different sample sizes.
## Train the default sampling model over a grid of sample sizes.
set.seed(1295)
seed <- 1295
B <- 10
seeds <- sample(1e5, B)
param <- list(learning_rate = 0.1, samSize = c(0.4, 0.7, 0.9), nrounds = 1000)
sampled_sam_models <- list()     # [[i]][[b]][[j]]: dataset i, seed b, samSize j
sampled_sam_test_loss <- list()  # matching per-iteration test losses
sampled_sam_res <- list()        # [[i]]: B x length(samSize) final test MSE

n_sizes <- length(param$samSize)
# One progress bar for the whole run. Previously it was re-created per
# dataset (resetting the display to 0%), and the progress value was off by
# one because j was incremented before setTxtProgressBar was called.
pb <- txtProgressBar(min = 0, max = B * 7 * n_sizes, style = 3)
for (i in 1:7) {
  # Preallocate instead of growing with c().
  sub_models <- vector("list", B)
  sub_test_loss <- vector("list", B)
  res_mat <- matrix(nrow = B, ncol = n_sizes)
  for (b in seq_len(B)) {
    # Loads x.train/y.train/x.test/y.test into the global env (functions.r).
    dataset(i, seeds[b])
    sub_sub_models <- vector("list", n_sizes)
    sub_sub_test_loss <- vector("list", n_sizes)
    for (j in seq_len(n_sizes)) {
      s <- param$samSize[j]
      # Same seed for every sample size so runs are comparable.
      set.seed(seeds[b])
      mod <- sampling.agt.train(x.train, y.train, samSize = s,
                                learnRate = param$learning_rate,
                                Nrounds = param$nrounds,
                                force_continued_learning = FALSE)
      sub_sub_models[[j]] <- mod
      sam.pred <- sampling.agt.pred(mod, x.test)
      sub_sub_test_loss[[j]] <- sampling.agt.loss(mod, x.test, y.test)
      res_mat[b, j] <- loss_mse(y.test, sam.pred)
      setTxtProgressBar(pb, (i - 1) * B * n_sizes + (b - 1) * n_sizes + j)
    }
    sub_models[[b]] <- sub_sub_models
    sub_test_loss[[b]] <- sub_sub_test_loss
  }
  sampled_sam_res[[i]] <- res_mat  # store once per dataset
  sampled_sam_models <- c(sampled_sam_models, list(sub_models))
  sampled_sam_test_loss <- c(sampled_sam_test_loss, list(sub_test_loss))
}
close(pb)
##
|
| | 0%
|
|= | 1%
|
|= | 2%
|
|== | 2%
|
|== | 3%
|
|=== | 4%
|
|=== | 5%
|
|==== | 5%
|
|==== | 6%
|
|===== | 7%
|
|===== | 8%
|
|====== | 8%
|
|====== | 9%
|
|======= | 10%
|
|======== | 11%
|
|======== | 12%
|
|========= | 12%
|
|========= | 13%
|
|========== | 14%
|
|========== | 15%
|
| | 0%
|
|=========== | 15%
|
|=========== | 16%
|
|============ | 17%
|
|============ | 18%
|
|============= | 18%
|
|============= | 19%
|
|============== | 20%
|
|=============== | 21%
|
|=============== | 22%
|
|================ | 22%
|
|================ | 23%
|
|================= | 24%
|
|================= | 25%
|
|================== | 25%
|
|================== | 26%
|
|=================== | 27%
|
|=================== | 28%
|
|==================== | 28%
|
|==================== | 29%
|
| | 0%
|
|===================== | 30%
|
|====================== | 31%
|
|====================== | 32%
|
|======================= | 32%
|
|======================= | 33%
|
|======================== | 34%
|
|======================== | 35%
|
|========================= | 35%
|
|========================= | 36%
|
|========================== | 37%
|
|========================== | 38%
|
|=========================== | 38%
|
|=========================== | 39%
|
|============================ | 40%
|
|============================= | 41%
|
|============================= | 42%
|
|============================== | 42%
|
|============================== | 43%
|
| | 0%
|
|=============================== | 44%
|
|=============================== | 45%
|
|================================ | 45%
|
|================================ | 46%
|
|================================= | 47%
|
|================================= | 48%
|
|================================== | 48%
|
|================================== | 49%
|
|=================================== | 50%
|
|==================================== | 51%
|
|==================================== | 52%
|
|===================================== | 52%
|
|===================================== | 53%
|
|====================================== | 54%
|
|====================================== | 55%
|
|======================================= | 55%
|
|======================================= | 56%
|
|======================================== | 57%
|
|======================================== | 58%
|
| | 0%
|
|========================================= | 58%
|
|========================================= | 59%
|
|========================================== | 60%
|
|=========================================== | 61%
|
|=========================================== | 62%
|
|============================================ | 62%
|
|============================================ | 63%
|
|============================================= | 64%
|
|============================================= | 65%
|
|============================================== | 65%
|
|============================================== | 66%
|
|=============================================== | 67%
|
|=============================================== | 68%
|
|================================================ | 68%
|
|================================================ | 69%
|
|================================================= | 70%
|
|================================================== | 71%
|
|================================================== | 72%
|
| | 0%
|
|=================================================== | 72%
|
|=================================================== | 73%
|
|==================================================== | 74%
|
|==================================================== | 75%
|
|===================================================== | 75%
|
|===================================================== | 76%
|
|====================================================== | 77%
|
|====================================================== | 78%
|
|======================================================= | 78%
|
|======================================================= | 79%
|
|======================================================== | 80%
|
|========================================================= | 81%
|
|========================================================= | 82%
|
|========================================================== | 82%
|
|========================================================== | 83%
|
|=========================================================== | 84%
|
|=========================================================== | 85%
|
|============================================================ | 85%
|
|============================================================ | 86%
|
| | 0%
|
|============================================================= | 87%
|
|============================================================= | 88%
|
|============================================================== | 88%
|
|============================================================== | 89%
|
|=============================================================== | 90%
|
|================================================================ | 91%
|
|================================================================ | 92%
|
|================================================================= | 92%
|
|================================================================= | 93%
|
|================================================================== | 94%
|
|================================================================== | 95%
|
|=================================================================== | 95%
|
|=================================================================== | 96%
|
|==================================================================== | 97%
|
|==================================================================== | 98%
|
|===================================================================== | 98%
|
|===================================================================== | 99%
|
|======================================================================| 100%
We see that for the default function with sample size less than 1, there is too much variance in the estimated generalization loss, so the ensembles stop far too early.
plot.test.loss <- function(test_loss) {
  ## Plot test-loss curves for the three sample sizes (0.4/0.7/0.9) together
  ## with the plain agt test loss for selected datasets and replications.
  ##
  ## test_loss: nested list; test_loss[[i]][[b]][[j]] is the test-loss curve
  ##   for dataset i, replication b, samSize index j (1 = 0.4, 2 = 0.7,
  ##   3 = 0.9).
  ## NOTE(review): reads the global `agt_test_loss` — assumes the samSize = 1
  ##   baseline experiment above has been run.
  sets <- 1:7         # datasets (1..7)
  iters <- c(1, 2, 3) # replications (1..10)
  for (i in sets) {
    for (b in iters) {
      plot(test_loss[[i]][[b]][[3]], col = "purple", cex = 0.8, type = "l",
           main = paste("Test loss for Dataset: ", i, " Iteration: ", b))
      lines(test_loss[[i]][[b]][[2]], col = "blue", cex = 0.8)
      lines(test_loss[[i]][[b]][[1]], col = "green", cex = 0.8)
      lines(agt_test_loss[[i]][[b]], col = "brown", cex = 1.2)
      # Legend label "SAM(0.9)" was missing its closing parenthesis.
      legend(x = "topright",
             legend = c("AGT", "SAM(0.4)", "SAM(0.7)", "SAM(0.9)"),
             col = c("brown", "green", "blue", "purple"), lwd = 1, lty = c(1))
    }
  }
}
plot.test.loss(sampled_sam_test_loss)
plot.gen.loss <- function(models) {
  ## Plot the estimated-generalization-loss curves (`$gen.loss` of each
  ## fitted model) for the three sample sizes, per dataset/replication.
  ##
  ## models: nested list; models[[i]][[b]][[j]] is the fitted model for
  ##   dataset i, replication b, samSize index j (1 = 0.4, 2 = 0.7, 3 = 0.9).
  sets <- 1:7         # datasets (1..7)
  iters <- c(1, 2, 3) # replications (1..10)
  for (i in sets) {
    for (b in iters) {
      plot(models[[i]][[b]][[3]]$gen.loss, col = "purple", cex = 0.8,
           type = "l",
           main = paste("gen loss for Dataset: ", i, " Iteration: ", b))
      lines(models[[i]][[b]][[2]]$gen.loss, col = "blue", cex = 0.8)
      lines(models[[i]][[b]][[1]]$gen.loss, col = "green", cex = 0.8)
      legend(x = "topright",
             legend = c("SAM(0.4)", "SAM(0.7)", "SAM(0.9)"),
             col = c("green", "blue", "purple"), lwd = 1, lty = c(1))
    }
  }
}
plot.gen.loss(sampled_sam_models)
# First alternative function: same grid experiment as above, but using
# first.sampling.agt.train as the trainer.
set.seed(1295)
seed <- 1295
B <- 10
seeds <- sample(1e5, B)
param <- list(learning_rate = 0.1, samSize = c(0.4, 0.7, 0.9), nrounds = 1000)
first_sam_models <- list()     # [[i]][[b]][[j]]: dataset i, seed b, samSize j
first_sam_test_loss <- list()  # matching per-iteration test losses
first_sam_res <- list()        # [[i]]: B x length(samSize) final test MSE

n_sizes <- length(param$samSize)
# One progress bar for the whole run. Previously it was re-created per
# dataset (resetting the display to 0%), and the progress value was off by
# one because j was incremented before setTxtProgressBar was called.
pb <- txtProgressBar(min = 0, max = B * 7 * n_sizes, style = 3)
for (i in 1:7) {
  # Preallocate instead of growing with c().
  sub_models <- vector("list", B)
  sub_test_loss <- vector("list", B)
  res_mat <- matrix(nrow = B, ncol = n_sizes)
  for (b in seq_len(B)) {
    # Loads x.train/y.train/x.test/y.test into the global env (functions.r).
    dataset(i, seeds[b])
    sub_sub_models <- vector("list", n_sizes)
    sub_sub_test_loss <- vector("list", n_sizes)
    for (j in seq_len(n_sizes)) {
      s <- param$samSize[j]
      # Same seed for every sample size so runs are comparable.
      set.seed(seeds[b])
      mod <- first.sampling.agt.train(x.train, y.train, samSize = s,
                                      learnRate = param$learning_rate,
                                      Nrounds = param$nrounds,
                                      force_continued_learning = FALSE)
      sub_sub_models[[j]] <- mod
      sam.pred <- sampling.agt.pred(mod, x.test)
      sub_sub_test_loss[[j]] <- sampling.agt.loss(mod, x.test, y.test)
      res_mat[b, j] <- loss_mse(y.test, sam.pred)
      setTxtProgressBar(pb, (i - 1) * B * n_sizes + (b - 1) * n_sizes + j)
    }
    sub_models[[b]] <- sub_sub_models
    sub_test_loss[[b]] <- sub_sub_test_loss
  }
  first_sam_res[[i]] <- res_mat  # store once per dataset
  first_sam_models <- c(first_sam_models, list(sub_models))
  first_sam_test_loss <- c(first_sam_test_loss, list(sub_test_loss))
}
close(pb)
##
|
| | 0%
|
|= | 1%
|
|= | 2%
|
|== | 2%
|
|== | 3%
|
|=== | 4%
|
|=== | 5%
|
|==== | 5%
|
|==== | 6%
|
|===== | 7%
|
|===== | 8%
|
|====== | 8%
|
|====== | 9%
|
|======= | 10%
|
|======== | 11%
|
|======== | 12%
|
|========= | 12%
|
|========= | 13%
|
|========== | 14%
|
|========== | 15%
|
| | 0%
|
|=========== | 15%
|
|=========== | 16%
|
|============ | 17%
|
|============ | 18%
|
|============= | 18%
|
|============= | 19%
|
|============== | 20%
|
|=============== | 21%
|
|=============== | 22%
|
|================ | 22%
|
|================ | 23%
|
|================= | 24%
|
|================= | 25%
|
|================== | 25%
|
|================== | 26%
|
|=================== | 27%
|
|=================== | 28%
|
|==================== | 28%
|
|==================== | 29%
|
| | 0%
|
|===================== | 30%
|
|====================== | 31%
|
|====================== | 32%
|
|======================= | 32%
|
|======================= | 33%
|
|======================== | 34%
|
|======================== | 35%
|
|========================= | 35%
|
|========================= | 36%
|
|========================== | 37%
|
|========================== | 38%
|
|=========================== | 38%
|
|=========================== | 39%
|
|============================ | 40%
|
|============================= | 41%
|
|============================= | 42%
|
|============================== | 42%
|
|============================== | 43%
|
| | 0%
|
|=============================== | 44%
|
|=============================== | 45%
|
|================================ | 45%
|
|================================ | 46%
|
|================================= | 47%
|
|================================= | 48%
|
|================================== | 48%
|
|================================== | 49%
|
|=================================== | 50%
|
|==================================== | 51%
|
|==================================== | 52%
|
|===================================== | 52%
|
|===================================== | 53%
|
|====================================== | 54%
|
|====================================== | 55%
|
|======================================= | 55%
|
|======================================= | 56%
|
|======================================== | 57%
|
|======================================== | 58%
|
| | 0%
|
|========================================= | 58%
|
|========================================= | 59%
|
|========================================== | 60%
|
|=========================================== | 61%
|
|=========================================== | 62%
|
|============================================ | 62%
|
|============================================ | 63%
|
|============================================= | 64%
|
|============================================= | 65%
|
|============================================== | 65%
|
|============================================== | 66%
|
|=============================================== | 67%
|
|=============================================== | 68%
|
|================================================ | 68%
|
|================================================ | 69%
|
|================================================= | 70%
|
|================================================== | 71%
|
|================================================== | 72%
|
| | 0%
|
|=================================================== | 72%
|
|=================================================== | 73%
|
|==================================================== | 74%
|
|==================================================== | 75%
|
|===================================================== | 75%
|
|===================================================== | 76%
|
|====================================================== | 77%
|
|====================================================== | 78%
|
|======================================================= | 78%
|
|======================================================= | 79%
|
|======================================================== | 80%
|
|========================================================= | 81%
|
|========================================================= | 82%
|
|========================================================== | 82%
|
|========================================================== | 83%
|
|=========================================================== | 84%
|
|=========================================================== | 85%
|
|============================================================ | 85%
|
|============================================================ | 86%
|
| | 0%
|
|============================================================= | 87%
|
|============================================================= | 88%
|
|============================================================== | 88%
|
|============================================================== | 89%
|
|=============================================================== | 90%
|
|================================================================ | 91%
|
|================================================================ | 92%
|
|================================================================= | 92%
|
|================================================================= | 93%
|
|================================================================== | 94%
|
|================================================================== | 95%
|
|=================================================================== | 95%
|
|=================================================================== | 96%
|
|==================================================================== | 97%
|
|==================================================================== | 98%
|
|===================================================================== | 98%
|
|===================================================================== | 99%
|
|======================================================================| 100%
Not very promising.
plot.test.loss(first_sam_test_loss)

plot.results <- function(results) {
  ## Plot the mean test loss of each sampling size relative to the plain agt
  ## baseline, one line per dataset; values below 1 mean the sampled model
  ## beat agt on average.
  ##
  ## results: list; results[[i]] is a B x 3 matrix of test MSE for dataset i
  ##   (columns correspond to samSize 0.4, 0.7, 0.9).
  ## NOTE(review): reads the global `res` (baseline MSE matrices) — assumes
  ##   the samSize = 1 comparison above has been run.
  sets <- 1:7  # datasets (1..7)
  plot(1, 1, xlim = c(0.35, 0.95), ylim = c(0.75, 1.5),
       main = "mean Test loss relative to AGT for different samSize on each dataset",
       ylab = "loss", xlab = "sampling size")
  for (i in sets) {
    # Dropped the xlim argument: lines() does not accept it and it only
    # produced "not a graphical parameter" warnings.
    lines(x = c(0.4, 0.7, 0.9),
          y = colMeans(results[[i]]) / mean(res[[i]][, 2]), col = i)
  }
  # Reference line and legend drawn once (were redrawn on every loop pass).
  abline(h = 1)
  legend(x = "topright",
         legend = c("1", "2", "3", "4", "5", "6", "7"),
         col = c(1, 2, 3, 4, 5, 6, 7), lwd = 1, lty = c(1))
}
plot.results(first_sam_res)
plot.gen.loss(first_sam_models)
Stops training if the minimum estimated generalization loss has not occurred within the last `a` iterations.
## Min-rule variant: min.sampling.agt.train stops when no new minimum of the
## estimated generalization loss occurs within the last a = 50 trees.
set.seed(1295)
seed <- 1295
B <- 10
seeds <- sample(1e5, B)
param <- list(learning_rate = 0.1, samSize = c(0.4, 0.7, 0.9), nrounds = 1000)
min_sam_models <- list()     # [[i]][[b]][[j]]: dataset i, seed b, samSize j
min_sam_test_loss <- list()  # matching per-iteration test losses
min_sam_res <- list()        # [[i]]: B x length(samSize) final test MSE

n_sizes <- length(param$samSize)
# One progress bar for the whole run. Previously it was re-created per
# dataset (resetting the display to 0%), and the progress value was off by
# one because j was incremented before setTxtProgressBar was called.
pb <- txtProgressBar(min = 0, max = B * 7 * n_sizes, style = 3)
for (i in 1:7) {
  # Preallocate instead of growing with c().
  sub_models <- vector("list", B)
  sub_test_loss <- vector("list", B)
  res_mat <- matrix(nrow = B, ncol = n_sizes)
  for (b in seq_len(B)) {
    # Loads x.train/y.train/x.test/y.test into the global env (functions.r).
    dataset(i, seeds[b])
    sub_sub_models <- vector("list", n_sizes)
    sub_sub_test_loss <- vector("list", n_sizes)
    for (j in seq_len(n_sizes)) {
      s <- param$samSize[j]
      # Same seed for every sample size so runs are comparable.
      set.seed(seeds[b])
      mod <- min.sampling.agt.train(x.train, y.train, samSize = s, a = 50,
                                    learnRate = param$learning_rate,
                                    Nrounds = param$nrounds,
                                    force_continued_learning = FALSE)
      sub_sub_models[[j]] <- mod
      sam.pred <- sampling.agt.pred(mod, x.test)
      sub_sub_test_loss[[j]] <- sampling.agt.loss(mod, x.test, y.test)
      res_mat[b, j] <- loss_mse(y.test, sam.pred)
      setTxtProgressBar(pb, (i - 1) * B * n_sizes + (b - 1) * n_sizes + j)
    }
    sub_models[[b]] <- sub_sub_models
    sub_test_loss[[b]] <- sub_sub_test_loss
  }
  min_sam_res[[i]] <- res_mat  # store once per dataset
  min_sam_models <- c(min_sam_models, list(sub_models))
  min_sam_test_loss <- c(min_sam_test_loss, list(sub_test_loss))
}
close(pb)
##
|
| | 0%
|
|= | 1%
|
|= | 2%
|
|== | 2%
|
|== | 3%
|
|=== | 4%
|
|=== | 5%
|
|==== | 5%
|
|==== | 6%
|
|===== | 7%
|
|===== | 8%
|
|====== | 8%
|
|====== | 9%
|
|======= | 10%
|
|======== | 11%
|
|======== | 12%
|
|========= | 12%
|
|========= | 13%
|
|========== | 14%
|
|========== | 15%
|
| | 0%
|
|=========== | 15%
|
|=========== | 16%
|
|============ | 17%
|
|============ | 18%
|
|============= | 18%
|
|============= | 19%
|
|============== | 20%
|
|=============== | 21%
|
|=============== | 22%
|
|================ | 22%
|
|================ | 23%
|
|================= | 24%
|
|================= | 25%
|
|================== | 25%
|
|================== | 26%
|
|=================== | 27%
|
|=================== | 28%
|
|==================== | 28%
|
|==================== | 29%
|
| | 0%
|
|===================== | 30%
|
|====================== | 31%
|
|====================== | 32%
|
|======================= | 32%
|
|======================= | 33%
|
|======================== | 34%
|
|======================== | 35%
|
|========================= | 35%
|
|========================= | 36%
|
|========================== | 37%
|
|========================== | 38%
|
|=========================== | 38%
|
|=========================== | 39%
|
|============================ | 40%
|
|============================= | 41%
|
|============================= | 42%
|
|============================== | 42%
|
|============================== | 43%
|
| | 0%
|
|=============================== | 44%
|
|=============================== | 45%
|
|================================ | 45%
|
|================================ | 46%
|
|================================= | 47%
|
|================================= | 48%
|
|================================== | 48%
|
|================================== | 49%
|
|=================================== | 50%
|
|==================================== | 51%
|
|==================================== | 52%
|
|===================================== | 52%
|
|===================================== | 53%
|
|====================================== | 54%
|
|====================================== | 55%
|
|======================================= | 55%
|
|======================================= | 56%
|
|======================================== | 57%
|
|======================================== | 58%
|
| | 0%
|
|========================================= | 58%
|
|========================================= | 59%
|
|========================================== | 60%
|
|=========================================== | 61%
|
|=========================================== | 62%
|
|============================================ | 62%
|
|============================================ | 63%
|
|============================================= | 64%
|
|============================================= | 65%
|
|============================================== | 65%
|
|============================================== | 66%
|
|=============================================== | 67%
|
|=============================================== | 68%
|
|================================================ | 68%
|
|================================================ | 69%
|
|================================================= | 70%
|
|================================================== | 71%
|
|================================================== | 72%
|
| | 0%
|
|=================================================== | 72%
|
|=================================================== | 73%
|
|==================================================== | 74%
|
|==================================================== | 75%
|
|===================================================== | 75%
|
|===================================================== | 76%
|
|====================================================== | 77%
|
|====================================================== | 78%
|
|======================================================= | 78%
|
|======================================================= | 79%
|
|======================================================== | 80%
|
|========================================================= | 81%
|
|========================================================= | 82%
|
|========================================================== | 82%
|
|========================================================== | 83%
|
|=========================================================== | 84%
|
|=========================================================== | 85%
|
|============================================================ | 85%
|
|============================================================ | 86%
|
| | 0%
|
|============================================================= | 87%
|
|============================================================= | 88%
|
|============================================================== | 88%
|
|============================================================== | 89%
|
|=============================================================== | 90%
|
|================================================================ | 91%
|
|================================================================ | 92%
|
|================================================================= | 92%
|
|================================================================= | 93%
|
|================================================================== | 94%
|
|================================================================== | 95%
|
|=================================================================== | 95%
|
|=================================================================== | 96%
|
|==================================================================== | 97%
|
|==================================================================== | 98%
|
|===================================================================== | 98%
|
|===================================================================== | 99%
|
|======================================================================| 100%
The results are good, but 50 extra trees are included after the minimum estimated generalization loss, so the method is not quite right. Potential bug.
# Visualize the min-rule results: test-loss curves per sample size, mean test
# loss relative to the agt baseline, and estimated generalization loss.
plot.test.loss(min_sam_test_loss)
plot.results(min_sam_res)
plot.gen.loss(min_sam_models)
### plotting training loss
# Overlay the training-loss curves of the min-rule models for each sample
# size (0.4 = green, 0.7 = blue, 0.9 = purple) on selected datasets/seeds.
sets <- 1:7    # dataset, between 1 and 7
iters <- 1:3   # seed, between 1 and 10
start <- 50
for (i in sets) {
  for (b in iters) {
    fits <- min_sam_models[[i]][[b]]
    plot(fits[[3]]$train.loss, col = "purple", cex = 0.8, type = "l",
         main = paste("Train loss for Dataset: ", i, " Iteration: ", b))
    lines(fits[[2]]$train.loss, col = "blue", cex = 0.8)
    lines(fits[[1]]$train.loss, col = "green", cex = 0.8)
    legend(x = "topright",
           legend = c("SAM(0.4)", "SAM(0.7)", "SAM(0.9)"),
           col = c("green", "blue", "purple"), lwd = 1, lty = c(1))
  }
}
## Full variant: same grid experiment, trainer full.sampling.agt.train.
set.seed(1295)
seed <- 1295
B <- 10
seeds <- sample(1e5, B)
param <- list(learning_rate = 0.1, samSize = c(0.4, 0.7, 0.9), nrounds = 1000)
full_sam_models <- list()     # [[i]][[b]][[j]]: dataset i, seed b, samSize j
full_sam_test_loss <- list()  # matching per-iteration test losses
full_sam_res <- list()        # [[i]]: B x length(samSize) final test MSE

n_sizes <- length(param$samSize)
# One progress bar for the whole run. Previously it was re-created per
# dataset (resetting the display to 0%), and the progress value was off by
# one because j was incremented before setTxtProgressBar was called.
pb <- txtProgressBar(min = 0, max = B * 7 * n_sizes, style = 3)
for (i in 1:7) {
  # Preallocate instead of growing with c().
  sub_models <- vector("list", B)
  sub_test_loss <- vector("list", B)
  res_mat <- matrix(nrow = B, ncol = n_sizes)
  for (b in seq_len(B)) {
    # Loads x.train/y.train/x.test/y.test into the global env (functions.r).
    dataset(i, seeds[b])
    sub_sub_models <- vector("list", n_sizes)
    sub_sub_test_loss <- vector("list", n_sizes)
    for (j in seq_len(n_sizes)) {
      s <- param$samSize[j]
      # Same seed for every sample size so runs are comparable.
      set.seed(seeds[b])
      mod <- full.sampling.agt.train(x.train, y.train, samSize = s,
                                     learnRate = param$learning_rate,
                                     Nrounds = param$nrounds,
                                     force_continued_learning = FALSE)
      sub_sub_models[[j]] <- mod
      sam.pred <- sampling.agt.pred(mod, x.test)
      sub_sub_test_loss[[j]] <- sampling.agt.loss(mod, x.test, y.test)
      res_mat[b, j] <- loss_mse(y.test, sam.pred)
      setTxtProgressBar(pb, (i - 1) * B * n_sizes + (b - 1) * n_sizes + j)
    }
    sub_models[[b]] <- sub_sub_models
    sub_test_loss[[b]] <- sub_sub_test_loss
  }
  full_sam_res[[i]] <- res_mat  # store once per dataset
  full_sam_models <- c(full_sam_models, list(sub_models))
  full_sam_test_loss <- c(full_sam_test_loss, list(sub_test_loss))
}
close(pb)
##
|
| | 0%
|
|= | 1%
|
|= | 2%
|
|== | 2%
|
|== | 3%
|
|=== | 4%
|
|=== | 5%
|
|==== | 5%
|
|==== | 6%
|
|===== | 7%
|
|===== | 8%
|
|====== | 8%
|
|====== | 9%
|
|======= | 10%
|
|======== | 11%
|
|======== | 12%
|
|========= | 12%
|
|========= | 13%
|
|========== | 14%
|
|========== | 15%
|
| | 0%
|
|=========== | 15%
|
|=========== | 16%
|
|============ | 17%
|
|============ | 18%
|
|============= | 18%
|
|============= | 19%
|
|============== | 20%
|
|=============== | 21%
|
|=============== | 22%
|
|================ | 22%
|
|================ | 23%
|
|================= | 24%
|
|================= | 25%
|
|================== | 25%
|
|================== | 26%
|
|=================== | 27%
|
|=================== | 28%
|
|==================== | 28%
|
|==================== | 29%
|
| | 0%
|
|===================== | 30%
|
|====================== | 31%
|
|====================== | 32%
|
|======================= | 32%
|
|======================= | 33%
|
|======================== | 34%
|
|======================== | 35%
|
|========================= | 35%
|
|========================= | 36%
|
|========================== | 37%
|
|========================== | 38%
|
|=========================== | 38%
|
|=========================== | 39%
|
|============================ | 40%
|
|============================= | 41%
|
|============================= | 42%
|
|============================== | 42%
|
|============================== | 43%
|
| | 0%
|
|=============================== | 44%
|
|=============================== | 45%
|
|================================ | 45%
|
|================================ | 46%
|
|================================= | 47%
|
|================================= | 48%
|
|================================== | 48%
|
|================================== | 49%
|
|=================================== | 50%
|
|==================================== | 51%
|
|==================================== | 52%
|
|===================================== | 52%
|
|===================================== | 53%
|
|====================================== | 54%
|
|====================================== | 55%
|
|======================================= | 55%
|
|======================================= | 56%
|
|======================================== | 57%
|
|======================================== | 58%
|
| | 0%
|
|========================================= | 58%
|
|========================================= | 59%
|
|========================================== | 60%
|
|=========================================== | 61%
|
|=========================================== | 62%
|
|============================================ | 62%
|
|============================================ | 63%
|
|============================================= | 64%
|
|============================================= | 65%
|
|============================================== | 65%
|
|============================================== | 66%
|
|=============================================== | 67%
|
|=============================================== | 68%
|
|================================================ | 68%
|
|================================================ | 69%
|
|================================================= | 70%
|
|================================================== | 71%
|
|================================================== | 72%
|
| | 0%
|
|=================================================== | 72%
|
|=================================================== | 73%
|
|==================================================== | 74%
|
|==================================================== | 75%
|
|===================================================== | 75%
|
|===================================================== | 76%
|
|====================================================== | 77%
|
|====================================================== | 78%
|
|======================================================= | 78%
|
|======================================================= | 79%
|
|======================================================== | 80%
|
|========================================================= | 81%
|
|========================================================= | 82%
|
|========================================================== | 82%
|
|========================================================== | 83%
|
|=========================================================== | 84%
|
|=========================================================== | 85%
|
|============================================================ | 85%
|
|============================================================ | 86%
|
| | 0%
|
|============================================================= | 87%
|
|============================================================= | 88%
|
|============================================================== | 88%
|
|============================================================== | 89%
|
|=============================================================== | 90%
|
|================================================================ | 91%
|
|================================================================ | 92%
|
|================================================================= | 92%
|
|================================================================= | 93%
|
|================================================================== | 94%
|
|================================================================== | 95%
|
|=================================================================== | 95%
|
|=================================================================== | 96%
|
|==================================================================== | 97%
|
|==================================================================== | 98%
|
|===================================================================== | 98%
|
|===================================================================== | 99%
|
|======================================================================| 100%
# Summary plots for the full-resampling SAM experiment (plot helpers from
# functions.r; definitions not visible here).
plot.test.loss(full_sam_test_loss)
plot.results(full_sam_res)
The estimated generalization loss shown below is off because these results were produced before a (since-fixed) bug was corrected; rerun the experiment before trusting this plot.
plot.gen.loss(full_sam_models)
# Experiment: SAM models re-using the same sample ("same.sampling" variant)
# over a grid of sample sizes. Mirrors the full-resampling experiment: for each
# of the 7 datasets and B seeds, one model per value in param$samSize.
set.seed(1295)
seed <- 1295
B <- 10
seeds <- sample(1e5, B)
param <- list("learning_rate" = 0.1, "samSize" = c(0.4, 0.7, 0.9), "nrounds" = 1000)
n_sizes <- length(param$samSize)  # hoisted: used for progress bar, loop, and result matrix

same_sam_models <- list()     # same_sam_models[[i]][[b]][[j]]: dataset i, seed b, samSize j
same_sam_test_loss <- list()  # matching per-model test-loss objects
same_sam_res <- list()        # same_sam_res[[i]]: B x n_sizes matrix of test MSE

# BUGFIX: create the progress bar ONCE and close it at the end. The original
# re-created it inside the dataset loop, resetting the display to 0% for every
# dataset (visible in the captured output) and leaking the previous bar; it
# was also never closed.
pb <- txtProgressBar(min = 0, max = B * 7 * n_sizes, style = 3)
for (i in 1:7) {
  sub_models <- list()
  sub_test_loss <- list()
  res_mat <- matrix(nrow = B, ncol = n_sizes)
  for (b in 1:B) {
    # Regenerate dataset i with seed b (helper from functions.r; presumably
    # defines x.train / y.train / x.test / y.test in the calling environment).
    dataset(i, seeds[b])
    sub_sub_models <- list()
    sub_sub_test_loss <- list()
    for (j in seq_len(n_sizes)) {
      s <- param$samSize[j]
      # Same seed for every samSize so the runs are directly comparable.
      set.seed(seeds[b])
      mod <- same.sampling.agt.train(x.train, y.train, samSize = s,
                                     learnRate = param$learning_rate, Nrounds = param$nrounds,
                                     force_continued_learning = F)
      sub_sub_models <- c(sub_sub_models, list(mod))
      sam.pred <- sampling.agt.pred(mod, x.test)
      test_mse <- sampling.agt.loss(mod, x.test, y.test)
      sub_sub_test_loss <- c(sub_sub_test_loss, list(test_mse))
      res_mat[b, j] <- loss_mse(y.test, sam.pred)
      # BUGFIX: count completed fits via loop index j; the original's manual
      # counter was incremented before this call, overshooting the maximum by
      # one on the final fit.
      setTxtProgressBar(pb, (i - 1) * B * n_sizes + (b - 1) * n_sizes + j)
    }
    sub_models <- c(sub_models, list(sub_sub_models))
    sub_test_loss <- c(sub_test_loss, list(sub_sub_test_loss))
    same_sam_res[[i]] <- res_mat
  }
  same_sam_models <- c(same_sam_models, list(sub_models))
  same_sam_test_loss <- c(same_sam_test_loss, list(sub_test_loss))
}
close(pb)
##
|
| | 0%
|
|= | 1%
|
|= | 2%
|
|== | 2%
|
|== | 3%
|
|=== | 4%
|
|=== | 5%
|
|==== | 5%
|
|==== | 6%
|
|===== | 7%
|
|===== | 8%
|
|====== | 8%
|
|====== | 9%
|
|======= | 10%
|
|======== | 11%
|
|======== | 12%
|
|========= | 12%
|
|========= | 13%
|
|========== | 14%
|
|========== | 15%
|
| | 0%
|
|=========== | 15%
|
|=========== | 16%
|
|============ | 17%
|
|============ | 18%
|
|============= | 18%
|
|============= | 19%
|
|============== | 20%
|
|=============== | 21%
|
|=============== | 22%
|
|================ | 22%
|
|================ | 23%
|
|================= | 24%
|
|================= | 25%
|
|================== | 25%
|
|================== | 26%
|
|=================== | 27%
|
|=================== | 28%
|
|==================== | 28%
|
|==================== | 29%
|
| | 0%
|
|===================== | 30%
|
|====================== | 31%
|
|====================== | 32%
|
|======================= | 32%
|
|======================= | 33%
|
|======================== | 34%
|
|======================== | 35%
|
|========================= | 35%
|
|========================= | 36%
|
|========================== | 37%
|
|========================== | 38%
|
|=========================== | 38%
|
|=========================== | 39%
|
|============================ | 40%
|
|============================= | 41%
|
|============================= | 42%
|
|============================== | 42%
|
|============================== | 43%
|
| | 0%
|
|=============================== | 44%
|
|=============================== | 45%
|
|================================ | 45%
|
|================================ | 46%
|
|================================= | 47%
|
|================================= | 48%
|
|================================== | 48%
|
|================================== | 49%
|
|=================================== | 50%
|
|==================================== | 51%
|
|==================================== | 52%
|
|===================================== | 52%
|
|===================================== | 53%
|
|====================================== | 54%
|
|====================================== | 55%
|
|======================================= | 55%
|
|======================================= | 56%
|
|======================================== | 57%
|
|======================================== | 58%
|
| | 0%
|
|========================================= | 58%
|
|========================================= | 59%
|
|========================================== | 60%
|
|=========================================== | 61%
|
|=========================================== | 62%
|
|============================================ | 62%
|
|============================================ | 63%
|
|============================================= | 64%
|
|============================================= | 65%
|
|============================================== | 65%
|
|============================================== | 66%
|
|=============================================== | 67%
|
|=============================================== | 68%
|
|================================================ | 68%
|
|================================================ | 69%
|
|================================================= | 70%
|
|================================================== | 71%
|
|================================================== | 72%
|
| | 0%
|
|=================================================== | 72%
|
|=================================================== | 73%
|
|==================================================== | 74%
|
|==================================================== | 75%
|
|===================================================== | 75%
|
|===================================================== | 76%
|
|====================================================== | 77%
|
|====================================================== | 78%
|
|======================================================= | 78%
|
|======================================================= | 79%
|
|======================================================== | 80%
|
|========================================================= | 81%
|
|========================================================= | 82%
|
|========================================================== | 82%
|
|========================================================== | 83%
|
|=========================================================== | 84%
|
|=========================================================== | 85%
|
|============================================================ | 85%
|
|============================================================ | 86%
|
| | 0%
|
|============================================================= | 87%
|
|============================================================= | 88%
|
|============================================================== | 88%
|
|============================================================== | 89%
|
|=============================================================== | 90%
|
|================================================================ | 91%
|
|================================================================ | 92%
|
|================================================================= | 92%
|
|================================================================= | 93%
|
|================================================================== | 94%
|
|================================================================== | 95%
|
|=================================================================== | 95%
|
|=================================================================== | 96%
|
|==================================================================== | 97%
|
|==================================================================== | 98%
|
|===================================================================== | 98%
|
|===================================================================== | 99%
|
|======================================================================| 100%
# Summary plots for the same-sample SAM experiment (plot helpers from
# functions.r; definitions not visible here).
plot.test.loss(same_sam_test_loss)
plot.results(same_sam_res)
plot.gen.loss(same_sam_models)